#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2017/12/21 10:22
# @Author  : Aries
# @Site    : 
# @File    : image_tools.py
# @Software: PyCharm Community Edition

import numpy as np

import gdal

import config

def img2array(img_fn):
    """
    Read a GDAL-readable raster (e.g. a GeoTIFF) into a channel-last array.

    :param img_fn: path of the raster file to open with ``gdal.Open``
    :return: ndarray shaped (rows, cols, bands). A single-band raster is
        returned as (rows, cols, 1) so callers always get a channel axis.
    :raises IOError: if GDAL cannot open ``img_fn``
    """
    raster = gdal.Open(img_fn)
    if raster is None:
        # Unless gdal.UseExceptions() is enabled, gdal.Open reports failure
        # by returning None rather than raising.
        raise IOError("Unable to open raster: %s" % img_fn)
    array = raster.ReadAsArray()
    if array.ndim == 2:
        # Single-band rasters come back as (rows, cols); add a leading band
        # axis so the channel-last roll below is always valid.
        array = array[np.newaxis, ...]
    # GDAL yields (bands, rows, cols); move bands last for Keras
    # channel-last format.
    return np.rollaxis(array, 0, 3)

def readtiff(imgDir):
    """
    Open a raster with GDAL and return it together with basic metadata.

    :param imgDir: path of the raster to open with ``gdal.Open``
    :return: tuple ``(layer, bands, gt)``: the opened GDAL dataset, its band
        count (``RasterCount``), and its geotransform (``GetGeoTransform()``)
    :raises IOError: if GDAL cannot open ``imgDir``
    """
    layer = gdal.Open(imgDir)
    if layer is None:
        # Without gdal.UseExceptions(), a failed open returns None; fail
        # here with a clear message instead of an AttributeError below.
        raise IOError("Unable to open raster: %s" % imgDir)
    gt = layer.GetGeoTransform()
    bands = layer.RasterCount
    return layer, bands, gt

def Val_raster(x, y, imgDir):
    """
    Sample every band of a raster at a single georeferenced point.

    :param x: X coordinate, in the same units as the raster's geotransform
    :param y: Y coordinate, in the same units as the raster's geotransform
    :param imgDir: path of the raster file
    :return: list with one pixel value per band, in band order (band 1 first)
    :raises IndexError: if (x, y) falls outside the raster extent
    """
    layer, bands, gt = readtiff(imgDir)

    # Invert the geotransform using only origin (gt[0], gt[3]) and pixel size
    # (gt[1], gt[5]); rotation terms gt[2]/gt[4] are ignored, so this assumes
    # a north-up raster.
    # Use floor rather than int(): int() truncates toward zero, which would
    # map points up to one pixel left of / above the origin onto pixel 0.
    px = int(np.floor((x - gt[0]) / gt[1]))
    py = int(np.floor((y - gt[3]) / gt[5]))

    if not (0 <= px < layer.RasterXSize and 0 <= py < layer.RasterYSize):
        raise IndexError(
            "Point (%s, %s) maps to pixel (%d, %d), outside the %dx%d raster %s"
            % (x, y, px, py, layer.RasterXSize, layer.RasterYSize, imgDir))

    # ReadAsArray(xoff, yoff, 1, 1) yields a 1x1 array; take its only value.
    return [layer.GetRasterBand(band_idx).ReadAsArray(px, py, 1, 1)[0][0]
            for band_idx in range(1, bands + 1)]